import os
import re

from langchain.chains import create_sql_query_chain
from langchain_community.utilities import SQLDatabase

from model.my_chat_model import ChatModel


# Fallback connection string (local dev database). Prefer setting MYSQL_URI so
# credentials are not hard-coded.
_DEFAULT_DB_URI = "mysql+pymysql://root:root@localhost:3306/ai"

# First Markdown code fence, optionally tagged ``sql``/``mysql``; captures the
# body. An unclosed fence runs to the end of the text.
_SQL_FENCE_RE = re.compile(r"```(?:(?:my)?sql\b)?(.*?)(?:```|\Z)", re.DOTALL | re.IGNORECASE)

# Label some models echo back from the text-to-SQL prompt format.
_SQL_QUERY_PREFIX = "SQLQuery:"


def _extract_sql(text):
    """Return the bare SQL statement contained in raw LLM output.

    Unwraps the first Markdown code fence (```sql, ```mysql or untagged) if
    present, strips surrounding whitespace and drops a leading "SQLQuery:"
    label (case-insensitive). Returns an empty string when nothing is left.
    """
    sql = text
    match = _SQL_FENCE_RE.search(sql)
    if match:
        sql = match.group(1)
    sql = sql.strip()
    if sql[:len(_SQL_QUERY_PREFIX)].lower() == _SQL_QUERY_PREFIX.lower():
        sql = sql[len(_SQL_QUERY_PREFIX):].strip()
    return sql


def test_sql():
    """End-to-end check: have the LLM write a MySQL query, then execute it.

    Builds the chat model and connects to the database, restricted to the
    ``user_info`` table. It then asks a natural-language question through
    ``create_sql_query_chain``, unwraps the model's answer and runs the
    resulting SQL, printing the raw output, the extracted SQL and the result.

    The connection URI is read from the ``MYSQL_URI`` environment variable,
    falling back to ``_DEFAULT_DB_URI``.

    Raises:
        ValueError: if no SQL statement can be extracted from the LLM output.
    """
    # 1. Obtain the LLM.
    llm = ChatModel().get_line_model()

    # 2. Connect to the database, including only the user_info table.
    db_uri = os.environ.get("MYSQL_URI", _DEFAULT_DB_URI)
    db = SQLDatabase.from_uri(db_uri, include_tables=["user_info"])

    # 3. Build the text-to-SQL chain.
    chain = create_sql_query_chain(llm, db)

    # 4. Ask the question ("query user info where the username is 张三 and the
    #    password is 1").
    question = "请根据条件用户名是张三和密码是1查询用户信息"
    raw_output = chain.invoke({"question": question})
    print(raw_output)

    # Models often wrap the query in a Markdown fence or prefix it with a
    # label; unwrap before sending it to the database.
    sql = _extract_sql(raw_output)
    if not sql:
        raise ValueError(f"LLM output contained no SQL statement: {raw_output!r}")
    print(sql)

    rs = db.run(sql)
    print(f"rs={rs}")


if __name__ == "__main__":
    # Runs only when this file is executed directly, not when it is imported.
    test_sql()
